import argparse
import json
import pickle
import os
from collections import namedtuple
import datetime
import torch
import torch.nn.functional as F
import yaml
import numpy as np

from generic.model_util import to_np

# Ice-hockey action names. reverse_standard_data() matches entries of
# ICEHOCKEY_GAME_FEATURES + ICEHOCKEY_ACTIONS to vector positions, so the
# order and spelling of this list must not change.
ICEHOCKEY_ACTIONS = [
    'assist', 'block', 'carry', 'check',
    'controlledbreakout', 'controlledentryagainst',
    'dumpin', 'dumpinagainst', 'dumpout',
    'faceoff', 'goal', 'icing', 'lpr', 'offside',
    'pass', 'pass1timer', 'penalty', 'pressure',
    'pscarry', 'pscheck', 'pslpr', 'pspuckprotection',
    'puckprotection', 'reception', 'receptionprevention',
    'shot', 'shot1timer',
    'socarry', 'socheck', 'sogoal', 'solpr', 'sopuckprotection', 'soshot',
]

# Soccer action names. reverse_standard_data() matches entries of
# SOCCER_GAME_FEATURES + SOCCER_ACTIONS to vector positions, so the order and
# spelling of this list must not change.
SOCCER_ACTIONS = [
    'caught-offside',
    'pass_from_fk', 'cross_from_fk', 'pass_from_corner', 'cross_from_corner',
    'cross', 'throw-in', 'through-ball', 'switch-of-play', 'long-ball', 'simple-pass',
    'take-on_drible', 'skill', 'tackle', 'interception', 'aerial-challenge',
    'clearance', 'ball-recovery', 'offside-provoked', 'own-goal',
    'penalty_shot', 'fk_shot', 'corner_shot', 'standard_shot', 'blocked_shot',
    'save', 'claim', 'punch', 'pick-up', 'smother', 'keeper-sweeper', 'penalty_save',
    'penalising_foul', 'minor_foul', 'penalty_obtained', 'dangerous_foul', 'dangerous_foul_obtained',
    'run_with_ball', 'dispossessed', 'bad-touch', 'miss', 'error', 'goal',
]

# Ice-hockey context feature names. reverse_standard_data() slices this list by
# position ([:2], [2:4], [7:8], [9:11]), so the order must not change.
# 'angel2gate' (sic) is used as a lookup key into the feature mean/std
# dictionaries, so its spelling is kept as-is.
ICEHOCKEY_GAME_FEATURES = [
    'xAdjCoord', 'yAdjCoord',
    'scoreDifferential', 'Penalty',
    'duration',
    'velocity_x', 'velocity_y',
    'time_remained', 'event_outcome',
    'home', 'away',
    'angel2gate',
]

# Soccer context feature names. reverse_standard_data() matches entries of
# SOCCER_GAME_FEATURES + SOCCER_ACTIONS to vector positions, so the order must
# not change.
SOCCER_GAME_FEATURES = [
    'angle', 'distance', 'duration',
    'gameTimeRemain', 'periodId',
    'interrupted', 'manPower', 'outcome', 'scoreDiff',
    'x', 'y',
    'velocity_x', 'velocity_y',
    'distance_x', 'distance_y',
    'home', 'away',
    'distance_to_goal',
]

# Integer codes for the categorical man-power situation labels.
MANPOWER_FEATURE = dict(shortHanded=1, evenStrength=0, powerPlay=-1)

# Integer codes for the categorical event-outcome labels.
OUTCOME_FEATURE = dict(successful=1, undetermined=0, failed=-1)

# Integer codes for the home/away labels.
HA_FEATURE = dict(home=1, away=0)

# One transition record: the current and next state-action data with their
# trace values, the three reward fields (read as reward_h / reward_a /
# reward_n elsewhere in this module), plus 'pid' and 'done'.
Transition = namedtuple(
    'Transition',
    [
        'state_action',
        'trace',
        'next_state_action',
        'next_trace',
        'reward_h',
        'reward_a',
        'reward_n',
        'pid',
        'done',
    ],
)


def load_event_data(data_path):
    """Read the pickled object stored in the file at ``data_path`` and return it.

    NOTE: unpickling can execute arbitrary code contained in the file, so this
    should only be pointed at files from a trusted source.
    """
    with open(data_path, 'rb') as handle:
        return pickle.load(handle)


def read_args():
    """Parse the command-line arguments of the process and return the namespace.

    One positional argument (``config_file``) is required; every option below is
    optional and stored under the upper-case ``dest`` name listed in the table.
    Values supplied on the command line are kept as strings (no ``type`` is set).
    """
    # Default label for --load_dqn_date_label: the time at which parsing happens.
    now_label = datetime.datetime.now().strftime('%b-%d-%Y-%H:%M')

    # (option strings, help text, destination attribute, default value)
    optional_arguments = (
        (("-t", "--train_flag"), "if training", "TRAIN_FLAG", '1'),
        (("-f", "--fixed_checkpoint"), "The check point label to be loaded", "CHECK_POINT", None),
        (("-d", "--debug_mode"), "whether to use the debug mode", "DEBUG_MODE", False),
        (("-m", "--learn_mode"), "the mode of learning ", "LEARN_MODE", 'normal'),
        (("-lde", "--load_dqn_episode"), "the training episode", "LOAD_EPISODE", 0),
        (("-ldd", "--load_dqn_date_label"), "the date of learned dqn", "LOAD_DATE", now_label),
        (("-l", "--log_file"), "log file", "LOG_FILE_PATH", None),
    )

    parser = argparse.ArgumentParser()
    parser.add_argument("config_file", help="path to configs file")
    for option_strings, help_text, destination, default_value in optional_arguments:
        parser.add_argument(*option_strings,
                            help=help_text,
                            dest=destination,
                            default=default_value,
                            required=False)
    return parser.parse_args()


def load_config(args=None):
    """Load the YAML configuration named by ``args.config_file``.

    :param args: object exposing ``config_file``, ``DEBUG_MODE`` and
        ``LOG_FILE_PATH`` attributes (e.g. the namespace from ``read_args``).
        Despite the ``None`` default, a real object is required.
    :return: tuple ``(config, args.DEBUG_MODE, args.LOG_FILE_PATH)`` where
        ``config`` is whatever ``yaml.safe_load`` produced from the file.
    :raises AssertionError: if the configuration file does not exist.
    """
    config_path = args.config_file
    assert os.path.exists(config_path), "Invalid configs file {0}".format(config_path)
    with open(config_path) as reader:
        parsed_config = yaml.safe_load(reader)
    return parsed_config, args.DEBUG_MODE, args.LOG_FILE_PATH


def pad_sequence(sequence, max_length):
    """Append all-zero rows to ``sequence`` until it has ``max_length`` rows.

    :param sequence: non-empty list of equally long rows (the width is taken
        from the first row).
    :param max_length: target number of rows; if the sequence is already at
        least this long, nothing is appended.
    :return: numpy array built from the original rows followed by the padding.
    """
    feature_width = len(sequence[0])
    rows_missing = max_length - len(sequence)
    zero_row = np.asarray([0] * feature_width)
    # Multiplying a list by a non-positive count yields [], i.e. no padding.
    padding = [zero_row] * rows_missing
    return np.asarray(sequence + padding)


def build_trace_mask(trace, max_trace_length):
    """Build a mask that marks the last step of each trace.

    :param trace: per-sample trace values; ``trace[b] - 1`` is used as the
        column index to mark for sample ``b``.
    :param max_trace_length: number of columns in the mask.
    :return: float array of shape ``(len(trace), max_trace_length)`` that is
        zero everywhere except for a single 1 per row.
    """
    num_samples = len(trace)
    trace_mask = np.zeros([num_samples, max_trace_length])
    for row in range(num_samples):
        last_step = trace[row] - 1
        trace_mask[row][last_step] = 1
    return trace_mask


def read_data(source_data_dir, file_name, output_team_ids=False):
    """Read one game's JSON file and pull out its events and game id.

    :param source_data_dir: directory prefix; it is concatenated with
        ``file_name`` as-is, so it must already end with a path separator.
    :param file_name: name of the JSON file inside the directory.
    :param output_team_ids: when true, also return the keys of the file's
        ``'rosters'`` mapping as a list.
    :return: ``(events, game_id)`` or ``(events, game_id, team_ids)``;
        ``events`` / ``game_id`` are ``None`` when the key is absent.
    """
    with open(source_data_dir + file_name) as json_file:
        game_record = json.load(json_file)

    result = (game_record.get('events'), game_record.get('gameId'))
    if not output_team_ids:
        return result
    return result + (list(game_record['rosters'].keys()),)


def judge_home_away(home_away_game_ids, teamId, gameId):
    """Tell whether ``teamId`` played game ``gameId`` as the home or away team.

    :param home_away_game_ids: mapping from game id to a dict whose
        ``'team1id'`` / ``'team2id'`` entries hold the integer team ids.
    :param teamId: team id, as an int or anything ``int()`` can convert.
    :param gameId: key of the game inside ``home_away_game_ids``.
    :return: ``'home'`` if the team is stored under ``'team1id'``,
        ``'away'`` if it is stored under ``'team2id'``.
    :raises ValueError: if the id matches an entry stored under any other key,
        or if no entry of the game matches the id at all.
    """
    game_home_away_info = home_away_game_ids[gameId]
    team_id = int(teamId)
    for key, value in game_home_away_info.items():
        if value != team_id:
            continue
        if key == 'team1id':
            return 'home'
        if key == 'team2id':
            return 'away'
        raise ValueError("Unknown team {0}".format(key))
    # The original code *returned* this exception object, so callers silently
    # received a ValueError instance instead of 'home'/'away'. Raise it.
    raise ValueError("Unknown teamId {0} in the game {1}".format(teamId, gameId))


def print_game_events_info(transition_game, team_values_all, apply_rnn, team_uncertainties_all, sports,
                           sanity_check_msg=None, write_file=None):
    """Print one human-readable summary line per event of a game.

    For every transition the standardized state-action vector is mapped back to
    named feature values with ``reverse_standard_data`` and printed together
    with the rewards, the two Q values and the (optional) uncertainty.

    :param transition_game: sequence of transitions exposing ``state_action``,
        ``trace``, ``reward_h``, ``reward_a`` and ``reward_n``.
    :param team_values_all: per-event values; entries ``[idx][0]`` and
        ``[idx][1]`` are printed. Must have the same length as
        ``transition_game``.
    :param apply_rnn: if true, the state-action data and rewards are sequences
        and the entry at position ``trace - 1`` is used; otherwise the fields
        are used directly.
    :param team_uncertainties_all: per-event uncertainties, or ``None`` to
        print ``inf`` instead.
    :param sports: ``'ice-hockey'`` or ``'soccer'``.
    :param sanity_check_msg: forwarded to ``reverse_standard_data`` in the RNN
        branch only.
    :param write_file: file object handed to ``print`` (``None`` = stdout).
    :raises ValueError: if ``sports`` is not one of the supported values.
    """
    if sports == 'ice-hockey':
        data_means, data_stds = read_feature_mean_scale(data_dir='../icehockey-data/')
        candidate_action_all = ICEHOCKEY_ACTIONS
    elif sports == 'soccer':
        data_means, data_stds = read_feature_mean_scale(data_dir='../soccer-data/')
        candidate_action_all = SOCCER_ACTIONS
    else:
        raise ValueError("Unknown sports: {0}".format(sports))
    assert len(transition_game) == len(team_values_all)
    for idx in range(0, len(transition_game)):
        transition = transition_game[idx]
        if apply_rnn:
            last_step = transition.trace - 1
            state_action_data = transition.state_action[last_step]
            reward_h = transition.reward_h[last_step]
            reward_a = transition.reward_a[last_step]
            reward_n = transition.reward_n[last_step]
            state_action_origin = reverse_standard_data(state_action_data=to_np(state_action_data),
                                                        data_means=data_means,
                                                        data_stds=data_stds,
                                                        sanity_check_msg=sanity_check_msg,
                                                        sports=sports)
        else:
            # Fix: the rewards used to be assigned only in the RNN branch, so
            # this path hit a NameError (or printed stale values) below. Read
            # them straight from the transition, as reward_look_ahead does.
            reward_h = transition.reward_h
            reward_a = transition.reward_a
            reward_n = transition.reward_n
            # NOTE(review): sanity_check_msg is not forwarded here, unlike the
            # RNN branch; left unchanged -- confirm whether that is intended.
            state_action_origin = reverse_standard_data(state_action_data=to_np(transition.state_action),
                                                        data_means=data_means,
                                                        data_stds=data_stds,
                                                        sports=sports)
        home_away = 'home' if state_action_origin['home'] > state_action_origin['away'] else 'away'
        # The performed action is the candidate with the largest strictly
        # positive value; it stays None when no candidate is above zero.
        action = None
        max_action_label = 0
        for candidate_action in candidate_action_all:
            if state_action_origin[candidate_action] > max_action_label:
                max_action_label = state_action_origin[candidate_action]
                action = candidate_action
        missing = float('inf')  # placeholder for features absent from the reversed data
        print("Event Index: {0:d}, Team: {1:s}, "
              "Action: {2:s}, Manpower: {3:2.0f}, "
              "Score Difference: {4:2.0f}, "
              "X Coord: {5:2.2f}, Y Coord: {6:2.2f},"
              "Reward_h: {7:2.0f}, Reward_a: {8:2.0f}, Reward_n: {9:2.0f}, "
              " Q values: {10:1.2f}/{11:1.2f}, Uncertainty:{12:2.4f}".format(
            idx,
            home_away,
            # Fix: '{2:s}' cannot format None, so stringify the "no action" case.
            str(action),
            state_action_origin.get('Penalty', missing),
            state_action_origin.get('scoreDifferential', missing),
            state_action_origin.get('xAdjCoord', missing),
            state_action_origin.get('yAdjCoord', missing),
            reward_h,
            reward_a,
            reward_n,
            team_values_all[idx][0],
            team_values_all[idx][1],
            team_uncertainties_all[idx] if team_uncertainties_all is not None else missing,
        ), flush=True, file=write_file
        )


def read_feature_max_min(data_dir):
    """Load the per-feature maximum and minimum tables of a dataset.

    :param data_dir: directory containing ``feature_max.json`` and
        ``feature_min.json``.
    :return: tuple ``(data_max_dict, data_min_dict)`` with the parsed JSON
        content of the two files.
    """
    loaded = []
    for file_name in ('/feature_max.json', '/feature_min.json'):
        with open(data_dir + file_name, 'r') as fp:
            loaded.append(json.load(fp))
    data_max_dict, data_min_dict = loaded
    return data_max_dict, data_min_dict


def read_feature_mean_scale(data_dir):
    """Load the per-feature mean and standard-deviation tables of a dataset.

    :param data_dir: directory containing ``feature_mean.json`` and
        ``feature_std.json``.
    :return: tuple ``(data_means_dict, data_stds_dict)`` with the parsed JSON
        content of the two files.
    """
    loaded = []
    for file_name in ('/feature_mean.json', '/feature_std.json'):
        with open(data_dir + file_name, 'r') as fp:
            loaded.append(json.load(fp))
    data_means_dict, data_stds_dict = loaded
    return data_means_dict, data_stds_dict

    # data_means = []
    # with open(data_dir+'/feature_mean.txt', 'r') as f:
    #     data_lines = f.readlines()
    # for data_line in data_lines:
    #     read_values = []
    #     read_strs = data_line.replace('[', '').replace(']', '').replace('\n', '').split(' ')
    #     for read_str in read_strs:
    #         if len(read_str) > 0:
    #             read_values.append(float(read_str))
    #     data_means += read_values
    #
    # data_stds = []
    # with open(data_dir+'/feature_scale.txt', 'r') as f:
    #     data_lines = f.readlines()
    # for data_line in data_lines:
    #     read_values = []
    #     read_strs = data_line.replace('[', '').replace(']', '').replace('\n', '').split(' ')
    #     for read_str in read_strs:
    #         if len(read_str) > 0:
    #             read_values.append(float(read_str))
    #     data_stds += read_values
    #
    # return data_means, data_stds


def reverse_standard_data(state_action_data, data_means, data_stds, sports, sanity_check_msg=None):
    """Undo the standardization of a state-action vector and name its entries.

    Each entry ``i`` of ``state_action_data`` is mapped to
    ``value * data_stds[name] + data_means[name]`` where ``name`` is the i-th
    feature of the layout selected by ``sports`` / ``sanity_check_msg``.

    :param state_action_data: indexable vector of standardized values, ordered
        like the selected feature layout.
    :param data_means: mapping feature name -> mean.
    :param data_stds: mapping feature name -> standard deviation.
    :param sports: ``'ice-hockey'`` or ``'soccer'``.
    :param sanity_check_msg: ``None`` for the full layout (game features
        followed by actions). Otherwise a reduced ice-hockey layout is chosen:
        a message containing ``'location'`` and ``'ha'`` keeps the coordinates
        and home/away flags; one containing ``'sd'``, ``'md'`` and ``'ha'``
        keeps score differential, penalty, remaining time and home/away flags.
    :return: dict mapping feature name -> de-standardized value.
    :raises ValueError: for an unknown sport, an unrecognized
        ``sanity_check_msg``, or a sanity-check layout requested for a sport
        other than ice-hockey.
    """
    if sanity_check_msg is None:
        if sports == 'ice-hockey':
            all_features = ICEHOCKEY_GAME_FEATURES + ICEHOCKEY_ACTIONS
        elif sports == 'soccer':
            all_features = SOCCER_GAME_FEATURES + SOCCER_ACTIONS
        else:
            raise ValueError("Unknown sports: {0}".format(sports))
    else:
        location_layout = 'location' in sanity_check_msg and 'ha' in sanity_check_msg
        context_layout = (not location_layout
                          and 'sd' in sanity_check_msg
                          and 'md' in sanity_check_msg
                          and 'ha' in sanity_check_msg)
        if not (location_layout or context_layout):
            # Fix: the message used to lack a '{0}' placeholder, so the
            # offending value was never shown.
            raise ValueError("Unknown sanity_check_msg: {0}".format(sanity_check_msg))
        if sports != 'ice-hockey':
            # Fix: these layouts are only defined for ice-hockey; previously
            # any other sport fell through and crashed with UnboundLocalError.
            raise ValueError("sanity_check_msg {0} is only supported for ice-hockey, "
                             "got sports: {1}".format(sanity_check_msg, sports))
        if location_layout:
            # xAdjCoord, yAdjCoord + home, away
            all_features = ICEHOCKEY_GAME_FEATURES[:2] + ICEHOCKEY_GAME_FEATURES[9:11] + ICEHOCKEY_ACTIONS
        else:
            # scoreDifferential, Penalty + time_remained + home, away
            all_features = (ICEHOCKEY_GAME_FEATURES[2:4] + ICEHOCKEY_GAME_FEATURES[7:8]
                            + ICEHOCKEY_GAME_FEATURES[9:11] + ICEHOCKEY_ACTIONS)

    reverse_data = {}
    for position, feature_name in enumerate(all_features):
        standard_value = state_action_data[position]
        reverse_data[feature_name] = standard_value * data_stds[feature_name] + data_means[feature_name]
    return reverse_data


def calculate_location_bin_expectation(locations, values, bin_x=1, bin_y=1):
    """Average ``values`` inside a grid of location bins.

    The grid covers x in [-100, 100] and y in [-45, 45]; rows are stored with
    the y axis reversed (largest y first).

    :param locations: sequence of ``(x, y)`` pairs, one per sample.
    :param values: numpy array with one entry (1-D) or one row of entries (2-D)
        per sample; for 2-D input every element of the row is counted.
    :param bin_x: bin width along x.
    :param bin_y: bin width along y.
    :return: ``(bin_expected_values, bin_num_store)``. Empty bins get an
        expected value of -1 and a count of 1 in the returned arrays.
    """
    x_min, x_max = -100, 100
    y_min, y_max = -45, 45
    x_dim = int((x_max - x_min) / bin_x)
    y_dim = int((y_max - y_min) / bin_y)
    bin_sum_store = np.zeros([y_dim, x_dim])
    bin_num_store = np.zeros([y_dim, x_dim])
    assert len(locations) == len(values)

    for idx, location in enumerate(locations):
        x_float = location[0] - x_min
        col = int((x_float - x_float % bin_x) / bin_x)
        y_float = location[1] - y_min
        row = y_dim - int((y_float - y_float % bin_y) / bin_y) - 1  # reverse the y axis
        if len(values.shape) == 2:
            added_sum, added_num = np.sum(values[idx]), len(values[idx])
        else:
            added_sum, added_num = values[idx], 1
        bin_sum_store[row][col] += added_sum
        bin_num_store[row][col] += added_num

    # Give empty bins a count of 1 and a sum of -1 so the division below is
    # well defined and produces -1 for them.
    empty_bins = bin_num_store == 0
    bin_num_store[empty_bins] = 1
    bin_sum_store[empty_bins] = -1

    return bin_sum_store / bin_num_store, bin_num_store


def summarize_feature_bin_samples(feature_exp_values, interested_feature_names, interested_feature_split_dict):
    """Group sample values by the bins their feature values fall into.

    :param feature_exp_values: iterable of samples where ``sample[0]`` is the
        value to store and ``sample[1][i]`` is the value of the i-th
        interested feature.
    :param interested_feature_names: feature names, in the order used by
        ``sample[1]``.
    :param interested_feature_split_dict: mapping feature name -> list of
        split points.
    :return: ``(bin_feature_value_store, feature_label_meanings)``. The store
        maps a key such as ``'0@1'`` (bin labels joined by ``'@'``) to the list
        of sample values in that bin combination. ``feature_label_meanings[i]``
        maps each bin label seen for feature ``i`` to a readable
        ``'low<name<high'`` description.

    A feature's bin label is the index of the first split point that is
    strictly greater than the value, minus one. A feature whose value is not
    below any split point contributes no label to the key.
    """
    bin_feature_value_store = {}
    feature_label_meanings = [{} for _ in interested_feature_names]
    for feature_exp_value in feature_exp_values:
        bin_labels = []
        for position, feature_name in enumerate(interested_feature_names):
            feature_value = feature_exp_value[1][position]
            split_points = interested_feature_split_dict[feature_name]
            for upper_idx, upper_bound in enumerate(split_points):
                if feature_value < upper_bound:
                    lower_idx = upper_idx - 1
                    feature_label_meanings[position][lower_idx] = '{0}<{1}<{2}'.format(
                        split_points[lower_idx], feature_name, upper_bound)
                    bin_labels.append(str(lower_idx))
                    break
        bin_key = '@'.join(bin_labels)
        bin_feature_value_store.setdefault(bin_key, []).append(feature_exp_value[0])

    return bin_feature_value_store, feature_label_meanings


def summarize_location_bin_samples(locations, values, num_tau, bin_x=1, bin_y=1,
                                   empirical_uncertainty_from_samples_flag=True):
    """Collect the samples of every location bin and summarize them per bin.

    The grid covers x in [-100, 100] and y in [-50, 50]; rows are stored with
    the y axis reversed (largest y first).

    :param locations: sequence of ``(x, y)`` pairs, one per sample.
    :param values: numpy array with one entry (1-D) or one row of entries (2-D)
        per sample.
    :param num_tau: not used by this function.
    :param bin_x: bin width along x.
    :param bin_y: bin width along y.
    :param empirical_uncertainty_from_samples_flag: when true, also compute the
        per-bin standard deviation and entropy (via ``samples2entropy``);
        otherwise those two grids stay zero.
    :return: ``(bin_feature_location_store, bin_expect_values, bin_std_values,
        bin_entropy_values, bin_num_values)``. The store maps ``"(row,col)"``
        to the concatenated samples of that bin. Bins without samples get the
        global mean of ``values`` as expectation, a count of 0 and, when the
        flag is set, a std of -1 and an entropy of -0.1.
    """
    x_min, x_max = -100, 100
    y_min, y_max = -50, 50
    x_dim = int((x_max - x_min) / bin_x)
    y_dim = int((y_max - y_min) / bin_y)
    bin_feature_location_store = {}
    assert len(locations) == len(values)
    global_avg_value = np.mean(values)

    # Pass 1: gather the raw samples of every occupied bin.
    for idx in range(len(locations)):
        x_float = locations[idx][0] - x_min
        col = int((x_float - x_float % bin_x) / bin_x)
        y_float = locations[idx][1] - y_min
        row = y_dim - int((y_float - y_float % bin_y) / bin_y) - 1  # reverse the y axis
        bin_key = "({0},{1})".format(row, col)
        if len(values.shape) == 2:
            new_samples = values[idx]
        else:
            new_samples = np.asarray([values[idx]])
        if bin_key in bin_feature_location_store.keys():
            new_samples = np.concatenate((bin_feature_location_store[bin_key], new_samples))
        bin_feature_location_store[bin_key] = new_samples

    # Pass 2: turn the gathered samples into per-bin statistics.
    bin_expect_values = np.zeros([y_dim, x_dim])
    bin_std_values = np.zeros([y_dim, x_dim])
    bin_entropy_values = np.zeros([y_dim, x_dim])
    bin_num_values = np.zeros([y_dim, x_dim])
    for i in range(y_dim):
        for j in range(x_dim):
            bin_key = "({0},{1})".format(i, j)
            if bin_key not in bin_feature_location_store:
                bin_num_values[i][j] = 0
                bin_expect_values[i][j] = global_avg_value
                if empirical_uncertainty_from_samples_flag:
                    bin_std_values[i][j] = -1
                    bin_entropy_values[i][j] = -0.1
                continue

            bin_samples = bin_feature_location_store[bin_key]
            if empirical_uncertainty_from_samples_flag:
                if len(bin_samples) > 1:
                    bin_std_values[i][j] = np.std(bin_samples)
                    entropy = samples2entropy(np.expand_dims(bin_samples, axis=(0, 1)))
                    bin_entropy_values[i][j] = entropy[0][0]
                else:
                    # A single sample carries no spread information.
                    bin_std_values[i][j] = 0
                    bin_entropy_values[i][j] = 0
            bin_expect_values[i][j] = np.mean(bin_samples)
            bin_num_values[i][j] = len(bin_samples)

    return bin_feature_location_store, bin_expect_values, bin_std_values, bin_entropy_values, bin_num_values


def reward_look_ahead(transition_all, begin_idx, apply_rnn, gamma):
    """
    Accumulate discounted rewards from ``begin_idx`` up to the next rewarded step.

    :param transition_all: sequence of transitions exposing ``reward_h``,
        ``reward_a``, ``reward_n`` (and ``next_trace`` when ``apply_rnn``).
    :param begin_idx: index of the first transition to look at.
    :param apply_rnn: if true, the reward fields are sequences and the entry at
        position ``next_trace - 1`` is used; otherwise they are used directly.
    :param gamma: discount factor; the step at offset ``t`` from ``begin_idx``
        is weighted by ``gamma ** t``.
    :return: ``(next_idx, h_cumu_rewards, a_cumu_rewards, n_cumu_rewards)``
        where ``next_idx`` is one past the first transition whose three rewards
        sum to more than 0, or ``None`` if no such transition exists.
    """
    h_cumu_rewards = 0
    a_cumu_rewards = 0
    n_cumu_rewards = 0
    for step, i in enumerate(range(begin_idx, len(transition_all))):
        transition = transition_all[i]
        if apply_rnn:
            last_step = transition.next_trace - 1
            reward_h = transition.reward_h[last_step]
            reward_a = transition.reward_a[last_step]
            reward_n = transition.reward_n[last_step]
        else:
            reward_h, reward_a, reward_n = transition.reward_h, transition.reward_a, transition.reward_n
        discount = gamma ** step
        h_cumu_rewards += discount * reward_h
        a_cumu_rewards += discount * reward_a
        n_cumu_rewards += discount * reward_n
        if reward_h + reward_a + reward_n > 0:
            return i + 1, h_cumu_rewards, a_cumu_rewards, n_cumu_rewards
    # No rewarded transition was found before the end of the sequence.
    return None


def normalization_01(x):
    """Min-max scale ``x`` so that its smallest value maps to 0 and its largest to 1.

    :param x: array-like of numbers (must be non-empty).
    :return: array of ``(x - min) / (max - min)``. When all values are equal
        the range is zero; an all-zero float array of the same shape is
        returned instead of dividing by zero (which used to produce NaNs).
    """
    lowest = np.min(x)
    value_range = np.max(x) - lowest
    shifted = x - lowest
    if value_range == 0:
        # Constant input: every element equals the minimum.
        return np.zeros_like(shifted, dtype=float)
    return shifted / value_range


def divide_dataset_according2date(all_data_files, train_rate, sports, if_split, if_return_split=False):
    """Split the game files chronologically into training / validation / testing sets.

    :param all_data_files: collection of game ids (as strings) that are available.
    :param train_rate: fraction of the files used for training; half of the
        remainder is used for validation and everything left for testing.
    :param sports: ``'ice-hockey'`` or ``'soccer'``; selects the game-dates
        JSON file that provides each game's date.
    :param if_split: when false, no split is done and ``all_data_files`` is
        returned for all three sets.
    :param if_return_split: when true, additionally return ``split_dates``:
        the date of the first game assigned to each of the three sets followed
        by the latest date in the game-dates file (``None`` when not splitting).
    :return: ``(training_ids, validate_ids, testing_ids)`` or the same followed
        by ``split_dates``.
    :raises ValueError: if splitting is requested for an unknown sport.
    """
    if not if_split:
        if if_return_split:
            return all_data_files, all_data_files, all_data_files, None
        return all_data_files, all_data_files, all_data_files

    from datetime import datetime as dt

    training_files_num = int(len(all_data_files) * train_rate)
    validate_files_num = int(len(all_data_files) * (1 - train_rate) / 2)
    if sports == 'ice-hockey':
        game_dates_dir = '../icehockey-data/game_dates_2018_2019.json'
    elif sports == 'soccer':
        game_dates_dir = '../soccer-data/game_dates_2017_2018.json'
    else:
        raise ValueError("Unknown sports: {0}".format(sports))
    with open(game_dates_dir, 'r') as file:
        game_dates = json.load(file)

    # Group the game ids by the date on which the game was played.
    game_dates_dict = {}
    for game_date_info in game_dates:
        game_date = dt.strptime(game_date_info['date'], "%Y-%m-%d")
        game_dates_dict.setdefault(game_date, []).append(game_date_info['gameid'])

    ordered_dates = sorted(game_dates_dict.keys())
    split_dates = [None, None, None, ordered_dates[-1]]
    # Slot 0 = training, 1 = validation, 2 = testing.
    game_id_splits = ([], [], [])
    for game_date in ordered_dates:
        for game_id in game_dates_dict[game_date]:
            game_label = str(game_id)
            if game_label not in all_data_files:
                continue
            # Fill training first, then validation; the rest is testing.
            if len(game_id_splits[0]) < training_files_num:
                slot = 0
            elif len(game_id_splits[1]) < validate_files_num:
                slot = 1
            else:
                slot = 2
            if split_dates[slot] is None:
                split_dates[slot] = game_date
            game_id_splits[slot].append(game_label)

    training_game_ids, validate_game_ids, testing_game_ids = game_id_splits
    if if_return_split:
        return training_game_ids, validate_game_ids, testing_game_ids, split_dates
    return training_game_ids, validate_game_ids, testing_game_ids


def aggregate_compute_distance_by_team(sample_home_bin_fea_value_store,
                                       output_home_bin_fea_value_store,
                                       sample_home_fea_label_meanings,
                                       sample_away_bin_fea_value_store,
                                       output_away_bin_fea_value_store,
                                       sample_away_fea_label_meanings,
                                       interest_fea_label_pair=[],
                                       risk_level=None):
    """Pool the home and away samples that fall into one bin of one feature.

    Store keys are bin labels joined by ``'@'`` (e.g. ``'0@2@1'``). An entry is
    selected when the label at position ``interest_fea_label_pair[0]`` equals
    ``interest_fea_label_pair[1]``.

    :param sample_home_bin_fea_value_store: key -> list of sample values (home).
    :param output_home_bin_fea_value_store: key -> list of output values (home).
    :param sample_home_fea_label_meanings: not used by this function.
    :param sample_away_bin_fea_value_store: key -> list of sample values (away).
    :param output_away_bin_fea_value_store: key -> list of output values (away).
    :param sample_away_fea_label_meanings: not used by this function.
    :param interest_fea_label_pair: ``[feature_position, bin_label]``.
    :param risk_level: not used by this function.
    :return: two flattened numpy arrays: the pooled sample values and the
        pooled output values, home entries first.
    """
    interest_feature_value_store = []
    interest_feature_output_store = []
    team_stores = (
        (sample_home_bin_fea_value_store, output_home_bin_fea_value_store),
        (sample_away_bin_fea_value_store, output_away_bin_fea_value_store),
    )
    for sample_store, output_store in team_stores:
        for label in sample_store.keys():
            label_indices = [int(part) for part in label.split('@')]
            if label_indices[interest_fea_label_pair[0]] != interest_fea_label_pair[1]:
                continue
            interest_feature_value_store.extend(sample_store[label])
            interest_feature_output_store.extend(output_store[label])
    return np.asarray(interest_feature_value_store).flatten(), np.asarray(interest_feature_output_store).flatten()


def compute_distance_by_teams(context_feature_split_dict,
                              sample_home_feature_label_meanings, predicted_exp_home_value_matrix,
                              predicted_std_home_value_matrix, outcome_std_home_value_matrix,
                              outcome_exp_home_value_matrix, home_num_matrix,
                              sample_away_feature_label_meanings, predicted_exp_away_value_matrix,
                              predicted_std_away_value_matrix, outcome_std_away_value_matrix,
                              outcome_exp_away_value_matrix, away_num_matrix,
                              distance_measure):
    """Compare predicted and observed statistics per context bin, for both teams.

    The context bins are indexed by ``(i, j, k)`` over the ``'Penalty'``,
    ``'scoreDifferential'`` and ``'time_remained'`` split lists. A bin is
    evaluated for a team when all three indices appear in that team's label
    meanings and the team's predicted expectation in the bin is positive.

    For each evaluated bin the squared difference between the predicted and
    the outcome expectation (and likewise for the std) is weighted by the
    bin's share of the team's total sample count.

    :param distance_measure: only ``'mse'`` is supported.
    :return: ``(all_record_msg, normalized_home_empirical_risks,
        normalized_away_empirical_risks, normalized_home_std_risks,
        normalized_away_std_risks)``; the message holds one line per bin that
        was evaluated for at least one team, with the string ``'None'`` in
        place of the differences of a team that was not evaluated.
    :raises ValueError: if an evaluated bin is met with an unknown
        ``distance_measure``.
    """
    record_template = ("Context: [{0}-{1}-{2}], "
                       "Home Exp model:{3}, Home Exp outcome:{4}, Home Exp diff:{5}, "
                       "Home std model:{6}, Home std outcome:{7}, Home std diff:{8}, Home Num:{9} "
                       "Away Exp model:{10}, Away Emp outcome:{11}, Away Exp diff:{12}, "
                       "Away std model:{13}, Away std outcome:{14}, Away std diff:{15}, Away Num:{16} ")

    def evaluate_cell(cell, label_meanings, predicted_exp, outcome_exp, predicted_std, outcome_std,
                      num_matrix, num_sum, empirical_risks, std_risks):
        """Append the weighted risks of one team's bin; return its raw diffs or None if skipped."""
        i, j, k = cell
        usable = i in label_meanings[0].keys() \
            and j in label_meanings[1].keys() \
            and k in label_meanings[2].keys() \
            and predicted_exp[cell] > 0
        if not usable:
            return None
        if distance_measure != 'mse':
            raise ValueError("Unknown distance_measure {0}".format(distance_measure))
        num_ratio = num_matrix[cell] / num_sum
        exp_diff = (predicted_exp[cell] - outcome_exp[cell]) ** 2
        empirical_risks.append(num_ratio * exp_diff)
        std_diff = (predicted_std[cell] - outcome_std[cell]) ** 2
        std_risks.append(num_ratio * std_diff)
        return exp_diff, std_diff

    normalized_home_empirical_risks = []
    normalized_away_empirical_risks = []
    normalized_home_std_risks = []
    normalized_away_std_risks = []
    home_num_sum = np.sum(home_num_matrix)
    away_num_sum = np.sum(away_num_matrix)
    record_lines = []
    for i in range(len(context_feature_split_dict['Penalty']) - 1):
        for j in range(len(context_feature_split_dict['scoreDifferential']) - 1):
            for k in range(len(context_feature_split_dict['time_remained']) - 1):
                cell = (i, j, k)
                feature_label_meanings = None

                home_exp_diff = home_std_diff = 'None'
                home_diffs = evaluate_cell(cell, sample_home_feature_label_meanings,
                                           predicted_exp_home_value_matrix, outcome_exp_home_value_matrix,
                                           predicted_std_home_value_matrix, outcome_std_home_value_matrix,
                                           home_num_matrix, home_num_sum,
                                           normalized_home_empirical_risks, normalized_home_std_risks)
                if home_diffs is not None:
                    home_exp_diff, home_std_diff = home_diffs
                    feature_label_meanings = sample_home_feature_label_meanings

                away_exp_diff = away_std_diff = 'None'
                away_diffs = evaluate_cell(cell, sample_away_feature_label_meanings,
                                           predicted_exp_away_value_matrix, outcome_exp_away_value_matrix,
                                           predicted_std_away_value_matrix, outcome_std_away_value_matrix,
                                           away_num_matrix, away_num_sum,
                                           normalized_away_empirical_risks, normalized_away_std_risks)
                if away_diffs is not None:
                    away_exp_diff, away_std_diff = away_diffs
                    # The away label meanings win when both teams were evaluated.
                    feature_label_meanings = sample_away_feature_label_meanings

                if feature_label_meanings is None:
                    continue
                record_lines.append(record_template.format(
                    feature_label_meanings[0][i],
                    feature_label_meanings[1][j],
                    feature_label_meanings[2][k],
                    predicted_exp_home_value_matrix[cell],
                    outcome_exp_home_value_matrix[cell],
                    home_exp_diff,
                    predicted_std_home_value_matrix[cell],
                    outcome_std_home_value_matrix[cell],
                    home_std_diff,
                    home_num_matrix[cell],
                    predicted_exp_away_value_matrix[cell],
                    outcome_exp_away_value_matrix[cell],
                    away_exp_diff,
                    predicted_std_away_value_matrix[cell],
                    outcome_std_away_value_matrix[cell],
                    away_std_diff,
                    away_num_matrix[cell],
                ) + '\n')

    all_record_msg = ''.join(record_lines)
    return all_record_msg, normalized_home_empirical_risks, normalized_away_empirical_risks, \
           normalized_home_std_risks, normalized_away_std_risks


class HistoryScoreCache:
    """Keep the most recent ``capacity`` scores and report their average."""

    def __init__(self, capacity=1):
        # Maximum number of scores kept once the cache is full.
        self.capacity = capacity
        self.reset()

    def reset(self):
        """Forget every stored score."""
        self.memory = []

    def __len__(self):
        return len(self.memory)

    def push(self, stuff):
        """Store the float ``stuff``, dropping the oldest score when full."""
        if len(self.memory) >= self.capacity:
            # Full: build a new list without the oldest entry.
            self.memory = self.memory[1:] + [stuff]
            return
        self.memory.append(stuff)

    def get_avg(self):
        """Return the mean of the stored scores."""
        scores = np.array(self.memory)
        return np.mean(scores)


class QValueDiscretization:
    """Map continuous Q-values onto integer bin labels.

    The bin boundaries (``self.split_values``) are fitted once at construction
    time, according to ``discret_mode``:

    * ``'GapSplit'``: ``split_num`` equal-width bins over the interval
      ``gap = (lower, upper)``. The boundaries are
      ``lower + k * width`` for ``k = 1 .. split_num`` (the last one equals
      ``upper``), so a value ``>= upper`` receives the label ``split_num``.
    * ``'NumberSplit'``: ``split_num - 1`` boundaries read from the sorted
      ``all_q_values`` at positions ``k * (len // split_num)`` for
      ``k = 1 .. split_num - 1``.

    Args:
        all_q_values: values used to fit the boundaries in ``'NumberSplit'``
            mode (not read in ``'GapSplit'`` mode).
        split_num: number of bins requested.
        discret_mode: ``'GapSplit'`` or ``'NumberSplit'``.
        gap: ``(lower, upper)`` interval used by ``'GapSplit'``.

    Raises:
        ValueError: if ``discret_mode`` is not one of the supported modes.
    """

    def __init__(self, all_q_values, split_num, discret_mode='GapSplit', gap=(0, 1)):
        self.all_q_values = all_q_values
        self.split_num = split_num
        self.gap = gap
        self.mode = discret_mode
        self.fit_split_values()

    def fit_split_values(self):
        """Compute ``self.split_values`` for the configured mode."""
        if self.mode == 'GapSplit':
            lower = self.gap[0]
            split_gap = float(self.gap[1] - lower) / self.split_num
            # Offset by the lower bound so intervals that do not start at 0
            # are split correctly; with the default gap (0, 1) this is a no-op.
            self.split_values = [lower + (i + 1) * split_gap for i in range(self.split_num)]
        elif self.mode == 'NumberSplit':
            all_q_values_sorted = sorted(self.all_q_values)
            split_gap = len(all_q_values_sorted) // self.split_num
            # Position 0 is skipped: it would be the minimum, not a boundary.
            self.split_values = [all_q_values_sorted[i * split_gap] for i in range(1, self.split_num)]
        else:
            raise ValueError("Unknown discretization mode {0}".format(self.mode))

    def discretize_q_values(self, q_values):
        """Return, for each Q-value, the index of the first boundary it is strictly below.

        A value that is not below any boundary gets ``len(self.split_values)``.
        """
        num_splits = len(self.split_values)
        return [
            next((idx for idx, split_value in enumerate(self.split_values) if q_value < split_value),
                 num_splits)
            for q_value in q_values
        ]


def handle_gda_features(fitting_target,
                        all_latent_features,
                        all_trace_length,
                        sanity_check_msg,
                        max_trace_length,
                        split_state_action_resnet,
                        apply_history=False
                        ):
    """Pick one feature vector per sample and stack them along a new first axis.

    Args:
        fitting_target: ``'QValues'`` or ``'Actions'``.
        all_latent_features: per-sample feature sequences, indexed by step.
        all_trace_length: per-sample trace lengths; step ``length - 1`` is
            treated as the last valid step of a sample.
        sanity_check_msg: not read by this function.
        max_trace_length: number of steps every history is left-padded to when
            ``apply_history`` is set.
        split_state_action_resnet: must be truthy for the ``'Actions'`` target.
        apply_history: for ``'QValues'``, use the zero-padded, flattened
            history of all valid steps instead of only the last step.

    Returns:
        ``np.stack`` of the selected per-sample vectors (axis 0 = sample).

    Raises:
        ValueError: for ``'Actions'`` combined with ``apply_history``, or for
            any other unsupported ``fitting_target`` / flag combination.
    """
    if fitting_target == 'QValues':
        selected = []
        for idx, sample_features in enumerate(all_latent_features):
            trace_length = all_trace_length[idx]
            if not apply_history:
                # Only the last valid step of the trace is kept.
                selected.append(sample_features[trace_length - 1])
                continue
            history = [sample_features[step] for step in range(trace_length)]
            zero_step = np.asarray([0] * len(sample_features[0]))
            # Left-pad with all-zero steps so every sample spans max_trace_length steps.
            padding = [zero_step for _ in range(max_trace_length - trace_length)]
            selected.append(np.asarray(padding + history).flatten())
        return np.stack(selected, axis=0)

    if fitting_target == 'Actions' and split_state_action_resnet:
        selected = []
        for idx, sample_features in enumerate(all_latent_features):
            if apply_history:
                raise ValueError("Not yet supp")
            # Drop the trailing len(ICEHOCKEY_ACTIONS) columns of the last valid step.
            last_step = all_trace_length[idx] - 1
            selected.append(sample_features[last_step, :-len(ICEHOCKEY_ACTIONS)])
        return np.stack(selected, axis=0)

    raise ValueError("Unknown fitting target {0}".format(fitting_target))


def label_visiting_shrink(all_labels, action_visited_flag):
    """Re-index labels so that never-visited actions leave no gaps.

    Each label is lowered by the number of unvisited actions that precede it in
    ``action_visited_flag``, which prevents empty classes downstream.

    Returns:
        A tuple ``(shrunk_labels, num_labels)`` where ``shrunk_labels`` is a
        numpy array and ``num_labels`` is its maximum plus one.
    """
    offsets = []
    unvisited_so_far = 0
    for visited in action_visited_flag:
        # Offset for this label = unvisited actions strictly before it.
        offsets.append(unvisited_so_far)
        unvisited_so_far += 0 if visited else 1

    shrunk = np.asarray([label - offsets[label] for label in all_labels])
    return shrunk, np.max(shrunk) + 1


def samples2entropy(output_samples, split_gap=0.1):
    """Compute one entropy score per ``[i][j]`` cell from that cell's samples.

    Samples must lie in the range 0 to 1 (known by definition). Each sample is
    binned by ``int(sample / split_gap)``, and the per-bin counts are handed to
    ``entropy`` as its ``logits`` argument.

    Args:
        output_samples: array-like with a ``shape``; the first two dimensions
            index the cells and ``output_samples[i][j]`` iterates over samples.
        split_gap: bin width.

    Returns:
        A numpy array of shape ``[shape[0], shape[1]]`` holding the scores.
    """
    num_rows, num_cols = output_samples.shape[0], output_samples.shape[1]
    output_entropys = np.zeros(shape=[num_rows, num_cols])
    # One bin per multiple of split_gap below 1, plus a final bin for 1 itself.
    num_bins = len(np.arange(0, 1, split_gap)) + 1

    for row, col in np.ndindex(num_rows, num_cols):
        bin_counts = [0.0] * num_bins
        for sample in output_samples[row][col]:
            bin_counts[int(sample / split_gap)] += 1.0
        output_entropys[row][col] = entropy(torch.tensor(bin_counts), dim=0).item()

    return output_entropys


def entropy(logits, dim=1):
    """Return the entropy of ``softmax(logits)`` taken along ``dim``.

    Computed as ``-sum(p * log p)`` with ``p = softmax(logits)`` and
    ``log p = log_softmax(logits)``; the reduced dimension is removed.
    """
    probs = F.softmax(logits, dim=dim)
    log_probs = F.log_softmax(logits, dim=dim)
    return -(probs * log_probs).sum(dim=dim)


def logsumexp(logits):
    """Return ``log(sum(exp(logits)))`` reduced over dimension 1 (that dimension is dropped)."""
    reduced = torch.logsumexp(logits, 1)
    return reduced


def read_features_within_events(game_label, source_data_dir, feature_name_list, sports):
    """Read selected per-event features from one game's JSON file.

    Args:
        game_label: game identifier; ``str(game_label)`` forms the file stem.
        source_data_dir: directory prefix. It is concatenated with the file
            name as a plain string, so it must end with a path separator.
        feature_name_list: names of the event fields to extract.
        sports: ``'ice-hockey'`` (file ``<label>-playsequence-wpoi.json``) or
            ``'soccer'`` (file ``<label>.json``).

    Returns:
        A list with one dict per entry of the file's ``'events'`` list, mapping
        each requested feature name to ``str(value)``. A field that is absent
        from an event therefore appears as the string ``'None'``.

    Raises:
        ValueError: if ``sports`` is not a supported value.
    """
    if sports == 'ice-hockey':
        file_name = str(game_label) + '-playsequence-wpoi.json'
    elif sports == 'soccer':
        file_name = str(game_label) + '.json'
    else:
        raise ValueError("Unknown sports: {0}".format(sports))

    with open(source_data_dir + file_name) as f:
        data = json.load(f)

    features_all = []
    for event in data.get('events'):
        feature_values = {}
        for feature_name in feature_name_list:
            raw_value = event.get(feature_name)
            try:
                value = str(raw_value)
            except Exception:
                # Best effort: keep the raw value if it cannot be stringified.
                # Deliberately not a bare ``except`` so that KeyboardInterrupt
                # and SystemExit are never swallowed.
                value = raw_value
            feature_values[feature_name] = value
        features_all.append(feature_values)

    return features_all


# if __name__ == "__main__":
#
#     tmp_dir = '../icehockey-data/NHL_players_game_summary_201819.csv'
#     with open(tmp_dir) as f:
#         tmp = json.load(f)
#     print(tmp)
#
#     source_data_dir = '../icehockey-data/2018-2019/'
#     saved_data_dir = '../icehockey-data/saved_data_2018_2019/'
#     team_names_dir = '../icehockey-data/team_name_2018_19.json'
#     interested_game_id = ['16694']
#
#     with open(team_names_dir) as f:
#         team_names_info_list = json.load(f)
#
#     team_names_dict = {}
#     for team_names_info in team_names_info_list:
#         team_names_dict.update({str(team_names_info['teamid']): team_names_info['name']})
#
#     for game_file in os.listdir(source_data_dir):
#         _, game_id, team_ids = read_data(source_data_dir=source_data_dir,
#                                          file_name=game_file,
#                                          output_team_ids=True)
#         if game_id in interested_game_id:
#             print('{0} v.s. {1}'.format(team_names_dict[team_ids[0]],
#                                         team_names_dict[team_ids[1]]))


def read_player_info(player_info_path='../icehockey-data/player_info_2018_2019.json'):
    """Load the player-info JSON file and index display data by player id.

    Args:
        player_info_path: path of the JSON file. It must hold an object that
            maps each player id to an object with the keys ``'first_name'``,
            ``'last_name'``, ``'position'`` and ``'teamName'``. The default is
            the location this function previously had hard-coded.

    Returns:
        A dict mapping each player id to
        ``["<first_name> <last_name>", position, teamName]``.
    """
    with open(player_info_path, 'rb') as player_file:
        player_info_dict = json.load(player_file)

    player_id_name = {}
    for player_id, info in player_info_dict.items():
        full_name = info['first_name'] + " " + info['last_name']
        player_id_name[player_id] = [full_name, info['position'], info['teamName']]
    return player_id_name


def get_nf_input(agent, state_action, trace, apply_history, values_cond=None, if_normalized=True,
                 sanity_check_msg=None):
    """Build the ``(tgt_data, tgt_cond)`` tensor pair from a batch of state-action traces.

    Args:
        agent: object exposing ``sports``, ``device``, the flags
            ``maf_cond_act`` / ``maf_cond_value`` and the per-feature lookups
            ``data_maxs``, ``data_mins``, ``data_means``, ``data_stds``
            (indexed by feature name).
        state_action: sequence of per-sample tensors, indexed as
            ``state_action[i][step, :]``.
        trace: per-sample trace lengths; ``trace[i] - 1`` is the step used when
            ``apply_history`` is false.
        apply_history: if true, stack all samples and flatten each one to a
            single row instead of selecting only the last step.
        values_cond: tensor used as (or concatenated onto) the conditioning
            tensor when ``agent.maf_cond_value`` is set.
        if_normalized: if true, min-max rescale the outputs using the
            standardized per-feature minimum and maximum.
        sanity_check_msg: ``None``, or a string containing both ``'location'``
            and ``'ha'``, in which case only the first two feature bounds are
            kept.

    Returns:
        ``(tgt_data, tgt_cond)``; ``tgt_cond`` is ``None`` when no conditioning
        applies.

    Raises:
        ValueError: for an unknown ``agent.sports`` or an unsupported
            ``sanity_check_msg``.
    """
    if agent.sports == 'ice-hockey':
        all_features = ICEHOCKEY_GAME_FEATURES + ICEHOCKEY_ACTIONS
        actions = ICEHOCKEY_ACTIONS
    elif agent.sports == 'soccer':
        all_features = SOCCER_GAME_FEATURES + SOCCER_ACTIONS
        actions = SOCCER_ACTIONS
    else:
        raise ValueError("Unknown sports {0}".format(agent.sports))

    # Express the raw per-feature max/min in standardized units: (x - mean) / std.
    standard_data_maxs = []
    standard_data_mins = []
    for feature in all_features:
        standard_data_maxs.append((agent.data_maxs[feature] - agent.data_means[feature]) / agent.data_stds[feature])
        standard_data_mins.append((agent.data_mins[feature] - agent.data_means[feature]) / agent.data_stds[feature])
    standard_data_maxs = torch.tensor(standard_data_maxs).to(agent.device)
    standard_data_mins = torch.tensor(standard_data_mins).to(agent.device)

    if sanity_check_msg is None:
        pass
    elif 'location' in sanity_check_msg and 'ha' in sanity_check_msg:
        # Keep only the bounds of the first two features.
        # NOTE(review): presumably the inputs were reduced to the same two
        # columns by the caller; otherwise the rescaling below cannot match.
        standard_data_maxs = standard_data_maxs[:2]
        standard_data_mins = standard_data_mins[:2]
    else:
        # The original message had no placeholder, so the value was never shown.
        raise ValueError("Unknown sanity_check_msg {0}".format(sanity_check_msg))

    batch_size = len(state_action)
    if apply_history:
        # NOTE(review): when agent.maf_cond_act is set this branch leaves
        # tgt_cond as None, yet the normalization and concatenation below then
        # operate on it; confirm that combination is never used.
        if agent.maf_cond_act or agent.maf_cond_value:
            tgt_data = torch.reshape(torch.stack(state_action)[:, :, :], shape=(batch_size, -1))
            tgt_cond = None
        else:
            tgt_data = torch.reshape(torch.stack(state_action)[:, :-len(actions), :], shape=(batch_size, -1))
            tgt_cond = torch.reshape(torch.stack(state_action)[:, -len(actions):, :], shape=(batch_size, -1))
    else:
        tgt_data = []
        if agent.maf_cond_act or agent.maf_cond_value:
            tgt_cond = []
        else:
            tgt_cond = None
        for i in range(batch_size):
            if agent.maf_cond_act:
                # The trailing len(actions) entries of the last step become the condition.
                tgt_data.append(state_action[i][trace[i] - 1, :][:-len(actions)])
                tgt_cond.append(state_action[i][trace[i] - 1, :][-len(actions):])
            else:
                tgt_data.append(state_action[i][trace[i] - 1, :])
        tgt_data = torch.stack(tgt_data)
        if agent.maf_cond_act:
            tgt_cond = torch.stack(tgt_cond)

    if if_normalized:
        # Min-max rescale: (x - min) / (max - min), with the bounds tiled over the batch.
        standard_data_maxs = standard_data_maxs.unsqueeze(0).repeat(batch_size, 1)
        standard_data_mins = standard_data_mins.unsqueeze(0).repeat(batch_size, 1)
        if agent.maf_cond_act:
            tgt_data = (tgt_data - standard_data_mins[:, :-len(actions)]) / (standard_data_maxs[:, :-len(actions)] -
                                                                             standard_data_mins[:, :-len(actions)])
            tgt_cond = (tgt_cond - standard_data_mins[:, -len(actions):]) / (standard_data_maxs[:, -len(actions):] -
                                                                             standard_data_mins[:, -len(actions):])
        else:
            tgt_data = (tgt_data - standard_data_mins) / (standard_data_maxs - standard_data_mins)

    if agent.maf_cond_value and agent.maf_cond_act:
        tgt_cond = torch.cat([tgt_cond, values_cond], dim=1)
    elif agent.maf_cond_value:
        tgt_cond = values_cond
    return tgt_data, tgt_cond


def get_game_time(agent, game_name, sanity_check_msg):
    """Collect per-event game times for one game, alongside the agent's computed values.

    For ``'ice-hockey'`` the time of each transition is ``3600`` minus the
    de-standardized ``'time_remained'`` feature of its last step. For
    ``'soccer'`` the times are read from the game's JSON events as
    ``min * 60 + sec`` and sorted in ascending order.

    Returns:
        ``(game_time_all, transition_game, output_game)`` where the last two
        come from ``agent.compute_values_by_game``.

    Raises:
        ValueError: if ``agent.sports`` is not a supported value.
    """
    output_game, transition_game = agent.compute_values_by_game(game_name, sanity_check_msg)
    sports = agent.sports

    if sports == 'soccer':
        json_path = os.path.join(agent.source_data_dir, game_name + '.json')
        with open(json_path) as f:
            events = json.load(f).get('events')
        game_time_all = sorted(event.get('min') * 60 + event.get('sec') for event in events)
    elif sports == 'ice-hockey':
        data_means, data_stds = read_feature_mean_scale(data_dir='../icehockey-data/')
        game_time_all = []
        for transition in transition_game:
            # Only the last valid step of the trace carries the current time.
            last_step = transition.state_action[transition.trace - 1]
            origin = reverse_standard_data(state_action_data=to_np(last_step),
                                           data_means=data_means,
                                           data_stds=data_stds,
                                           sports=sports, )
            game_time_all.append(3600 - origin['time_remained'])
    else:
        raise ValueError("Unknown sports: {0}".format(sports))

    return game_time_all, transition_game, output_game
